// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This is a generated file. DO NOT EDIT!

#include <vector>
#include <string>

#include "mace/proto/mace.pb.h"
#include "mace/public/mace.h"
#include "mace/port/env.h"
#include "mace/utils/logging.h"

namespace mace {
namespace {

// Fills in the identifying fields and the repeated input/output fields of an
// operator definition.
//
//   op           - operator definition to modify; dereferenced
//                  unconditionally, so it must not be null.
//   name, type   - stored via set_name() / set_type().
//   inputs       - appended in order via add_input().
//   outputs      - appended in order via add_output().
//   output_types - appended in order via add_output_type().
//   node_id      - stored via set_node_id().
//   mem_ids      - appended in order via add_mem_id().
//
// Reserve() is called on each repeated field with the final element count
// before the elements are appended.
void UpdateOp(mace::OperatorDef *op,
              const std::string &name,
              const std::string &type,
              const std::vector<std::string> &inputs,
              const std::vector<std::string> &outputs,
              const std::vector<mace::DataType> &output_types,
              uint32_t node_id,
              const std::vector<int> &mem_ids) {
  op->set_name(name);
  op->set_type(type);
  op->set_node_id(node_id);

  // Reserve() takes an int, so the size_t counts are narrowed explicitly.
  op->mutable_input()->Reserve(static_cast<int>(inputs.size()));
  // Bind by const reference: iterating by value would copy every string once
  // more before add_input() makes its own copy.
  for (const auto &input : inputs) {
    op->add_input(input);
  }

  op->mutable_output()->Reserve(static_cast<int>(outputs.size()));
  for (const auto &output : outputs) {
    op->add_output(output);
  }

  op->mutable_output_type()->Reserve(static_cast<int>(output_types.size()));
  for (const auto output_type : output_types) {
    op->add_output_type(output_type);
  }

  op->mutable_mem_id()->Reserve(static_cast<int>(mem_ids.size()));
  for (const auto mem_id : mem_ids) {
    op->add_mem_id(mem_id);
  }
}

}  // namespace
}  // namespace mace

namespace mace {
namespace {{tag}} {

{% for i in range(start, end) %}

// Fills `op` with the definition of one operator of the net. This function
// is rendered once for every index of the surrounding template loop.
//
// In order, the rendered body:
//   - adds the operator's arguments (name, whichever of f / i / s is set,
//     and any floats / ints lists);
//   - adds the output shapes that have at least one dimension;
//   - calls UpdateOp() with the name, type, inputs, outputs, output types,
//     node id and mem ids;
//   - adds the quantize info entries (scale, zero point, minval, maxval);
//   - for template `device` values 3, 4 or 5, adds the extra node-input
//     related fields from the conditional sections at the end.
void CreateOperator{{i}}(mace::OperatorDef *op) {
  MACE_LATENCY_LOGGER(2, "Create operator",
    {{ net.op[i].name|default("undefined")|tojson }});

  // Arguments: `arg` is re-pointed at each newly added entry. The template
  // loop variable of the same name is a separate, render-time value.
  mace::Argument *arg = nullptr;
  op->mutable_arg()->Reserve({{ net.op[i].arg|length }});
  {% for arg in net.op[i].arg %}

  arg = op->add_arg();
  arg->set_name({{ arg.name|tojson }});

  {%- if arg.HasField('f') %}
  arg->set_f({{ arg.f }});
  {%- endif %}
  {%- if arg.HasField('i') %}
  arg->set_i({{ arg.i }});
  {%- endif %}
  {%- if arg.HasField('s') %}
  arg->set_s({{ arg.s.decode('utf-8')|tojson }});
  {%- endif %}

  {% if arg.floats|length > 0 %}
    arg->mutable_floats()->Reserve({{ arg.floats|length }});
    {% for float_value in arg.floats %}
    arg->add_floats({{ float_value }});
    {% endfor %}
  {% endif %}

  {% if arg.ints|length > 0 %}
    arg->mutable_ints()->Reserve({{ arg.ints|length }});
    {% for int_value in arg.ints %}
    arg->add_ints({{ int_value }});
    {% endfor %}
  {% endif %}

  {% endfor %}

  // Output shapes.
  // NOTE(review): a shape with zero dims is skipped entirely, so fewer
  // output_shape entries may be added than were reserved (and than the op
  // has outputs) -- confirm this is intended.
  {% if net.op[i].output_shape|length > 0 %}
  op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
  {% for shape in net.op[i].output_shape %}
  {% if shape.dims|length > 0 %}
  {
      mace::OutputShape *output_shape = op->add_output_shape();
      output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
      {% for dim in shape.dims %}
      output_shape->add_dims({{ dim }});
      {% endfor %}
  }
  {% endif %}

  {% endfor %}
  {% endif %}

  // Output types are rendered as a list of integers and cast to
  // mace::DataType before being handed to UpdateOp().
  std::vector<int> output_types_int({ {{ net.op[i].output_type | join(', ') }} });
  std::vector<mace::DataType> output_types({{ net.op[i].output_type | length }});
  for (int k = 0; k < {{ net.op[i].output_type | length }}; ++k) {
    output_types[k] = static_cast<mace::DataType>(output_types_int[k]);
  }
  UpdateOp(op, {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}},
          { {{ net.op[i].input|stringfy }} },
          { {{ net.op[i].output|stringfy }} },
          output_types,
          {{ net.op[i].node_id }},
          { {{ net.op[i].mem_id | join(', ') }} });


  // Quantize info: one entry is added per element of the op's list.
  op->mutable_quantize_info()->Reserve({{ net.op[i].quantize_info | length }});
  {% for j in range(net.op[i].quantize_info|length) %}
    auto quantize_info{{j}} = op->add_quantize_info();
    quantize_info{{j}}->set_scale({{ net.op[i].quantize_info[j].scale }});
    quantize_info{{j}}->set_zero_point({{ net.op[i].quantize_info[j].zero_point }});
    quantize_info{{j}}->set_minval({{ net.op[i].quantize_info[j].minval }});
    quantize_info{{j}}->set_maxval({{ net.op[i].quantize_info[j].maxval }});
  {% endfor %}

  {% if device == 3 or device == 4 %}
    // Rendered only when the template variable `device` is 3 or 4.
    op->set_padding({{ net.op[i].padding }});
    {% if net.op[i].node_input | length > 0 %}
    std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
    std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ')}} });

    mace::NodeInput *node_input = nullptr;
    op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }});
    // `i` below is a C++ loop index; it is unrelated to the template index
    // that selects the operator in the surrounding template expressions.
    for (size_t i = 0; i < {{ net.op[i].node_input|length }}; ++i) {
      node_input = op->add_node_input();
      node_input->set_node_id(input_node_ids[i]);
      node_input->set_output_port(input_output_ports[i]);
    }
    {% endif %}

    {% if net.op[i].out_max_byte_size | length > 0 %}
    // The list's textual form has its square brackets replaced by braces so
    // that it reads as a C++ initializer list.
    std::vector<int> out_max_byte_sizes {{ net.op[i].out_max_byte_size | replace('[', '{') | replace(']', '}') }};
    op->mutable_out_max_byte_size()->Reserve({{ net.op[i].out_max_byte_size|length }});
    for (size_t i = 0; i < {{ net.op[i].out_max_byte_size|length }}; ++i) {
      op->add_out_max_byte_size(out_max_byte_sizes[i]);
    }
    {% endif %}
  {% endif %}

  {% if device == 5 %}
    // Rendered only when the template variable `device` is 5; node inputs
    // get a node id here but no output port.
    {% if net.op[i].node_input | length > 0 %}
    std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });

    mace::NodeInput *node_input = nullptr;
    op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }});
    for (size_t i = 0; i < {{ net.op[i].node_input|length }}; ++i) {
      node_input = op->add_node_input();
      node_input->set_node_id(input_node_ids[i]);
    }
    {% endif %}

  {% endif %}
}

{% endfor %}

}  // namespace {{tag}}
}  // namespace mace
